package com.alatus.djl.app;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * author: Alatus
 * date: 2025/2/2
 * email: alatuslee@qq.com
 * description:使用DJL训练大模型
 */
@RestController
@Slf4j
public class DJL {
//    测试模型
    @GetMapping("/predict")
    public String predict() throws IOException, MalformedModelException, TranslateException {
//        搞个图片先,准备测试数据
        Image image = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");

//        加载模型
        Path path = Paths.get("build/mlp");
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES,new int[]{128,64}));
        model.load(path);

//        预测(给模型一个新的输入,让它来判断我们的输入内容)

//        获取一个转换器
        ImageClassificationTranslator build = ImageClassificationTranslator.builder()
                .addTransform(new ToTensor())//转换器
                .addTransform(new Resize(28, 28))//设置图片尺寸
                .build();
//        获取预测器
        Predictor<Image, Classifications> predictor = model.newPredictor(build);
//        预测图片分类
        Classifications predict = predictor.predict(image);
        log.info(predict.toString());
        return predict.toString();
    }

    @GetMapping("/predictPic")
    public String predictImage() throws IOException, MalformedModelException, TranslateException {
        InputStream imageStream = getClass().getClassLoader().getResourceAsStream("static/3.png");
        Image image = null;
        if (imageStream == null) {
            // 处理图片没有找到的情况
            System.out.println("Image not found!");
        } else {
            image = ImageFactory.getInstance().fromInputStream(imageStream);
        }
        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES,new int[]{128,64}));
        model.load(Paths.get("build/mlp"));
        ImageClassificationTranslator build = ImageClassificationTranslator.builder().addTransform(new Resize(28, 28))
                .addTransform(new CenterCrop())//中心裁剪
                .addTransform(new ToTensor())
                .build();
        Predictor<Image, Classifications> predictor = model.newPredictor(build);
        Classifications predict = predictor.predict(image);
        log.info(predict.toString());
        return predict.toString();
    }

    @GetMapping("/fullModel")
    public String fullModel() throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
//        完全训练一个模型
//        准备数据集(这里我们用的是官方自带的),用自己的数据集就自定义DataSet
        RandomAccessDataset trainDataset = getDataset(Dataset.Usage.TRAIN);
        RandomAccessDataset validationDataset = getDataset(Dataset.Usage.TEST);

//        自定义数据集的例子,这里的set方法的都是必须要填的
//        ImageFolder build = ImageFolder.builder()
//                .addTransform()//添加转换器
//                .optImageSize()//设置图片尺寸大小
//                .optImageWidth()
//                .optImageHeight()
//                .setSampling(64,true)//设置采样信息,一次64张图片,随机采样
//                .setRepositoryPath()
//                .build();//设置数据集的存储路径
//        build.getData()//获取数据集就得到了

//        构建神经网络,这边我们采用直接用多层感知机的方式,而不是Block块的方式
//        因为这个案例是自带的,我们直接使用Mnist的参数即可
//        这里的神经网路MLP本身就是一个Block块:public class Mlp extends SequentialBlock
        Mlp mlp = new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES,new int[]{128,64});
//        这里我们的层数如果太多了,他的精度会比较高,但也会导致一些过拟合的情况存在
//        过拟合就是过度敏感了,图片里面一点点接近的信息他都认为是它正确的
//        比如说照片里面有一个鸟,但是鸟的图片和某一个他认识的东西近似,他就认为是它认为的那个东西然后会认为它是正确的
//        这就叫过拟合
//        欠拟合就是训练量太少了,他认不出来,啥东西他都认为是自行车

        //        构建模型(模型应用上面的神经网络)
        try(Model model = Model.newInstance("mlp")){
//            设置神经网络
            model.setBlock(mlp);
//            接下来训练这个神经网络,配置了损失函数,精度,训练监听器
            String output = "build/mlp";

//            训练配置信息
            DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .addEvaluator(new Accuracy())
                    .addTrainingListeners(TrainingListener.Defaults.logging(output));


            //        训练配置(训练集)
//            基于给到的训练配置信息开始训练
            try(Trainer trainer = model.newTrainer(config)){
//                查看训练期间的详细指标数据
                trainer.setMetrics(new Metrics());
//                初始化训练器
                trainer.initialize(new Shape(1, Mnist.IMAGE_HEIGHT, Mnist.IMAGE_WIDTH));
//                接下来做拟合(拟合也就是训练)
//                trainer.fit(trainDataset, 10);
//                这里我们用EasyTrain来训练
//                训练5次
                EasyTrain.fit(trainer,17, trainDataset, validationDataset);
                TrainingResult result = trainer.getTrainingResult();
                log.info("训练结果:"+result.toString());

                //        保存模型
                model.save(Paths.get(output),"mlp");
                return "模型训练完成,并保存成功";
            }
        }
    }

    private RandomAccessDataset getDataset(Dataset.Usage usage) throws IOException {
        Mnist build = Mnist.builder().setSampling(64, true)//设置采样信息
                .optUsage(usage)
                .optLimit(64)
                .build();
//        弄个进度条
        build.prepare(new ProgressBar());
        return build;
    }


    //    机器学习最基本的,我们要把我们需要处理的数据给转为一个N维向量
//    只有转为一个向量了,才能继续向下处理
    @GetMapping("/test01")
    public String test01() {
        try(NDManager manager = NDManager.newBaseManager()){
//            我们通过这个manager创建向量
//            这里的Shape就是N维数组的形状
//            我们这里创建的是一个2乘以3的矩阵(N维向量)
//            这里的ones指的是内容都是1填充的
//            输出的1.代表这是一个1的float值
            NDArray ones = manager.ones(new Shape(2, 3));
            log.info(ones.toString());
//            这里,我们同样可以自己创建一个矩阵
//            通过创建对应的数组和给予我们需要的形状来创建一个矩阵
            NDArray array = manager.create(new float[]{1.14F, 5.14F, 1.9F, 1.9F, 8.10F, 1.14f}, new Shape(2, 3));
            log.info(array.toString());
//            矩阵计算
//            如矩阵转质,这里我们的矩阵二乘三的矩阵就变成了三乘二
            NDArray transpose = array.transpose();
            log.info(transpose.toString());
            return "矩阵"+ones+"和"+array+"的转置为"+transpose;
        }
    }
//    我们这里的这些矩阵你可以模拟为从数据集中加载的
//    数据集是用于训练机器学习模型的数据集合
//    机器学习通常使用三个数据集,训练集,验证集和测试集

//    训练集是我们用来训练的实际数据集,模型从这些数据中学习权重和参数

//    验证集用来在训练过程中评估给定模型,它帮助机器学习工程师在模型开发阶段微调超参数
//    模型不从验证数据集学习,验证数据集是可选的

//    测试数据集提供了用于评估模型性能的黄金标准,它只在模型完全训练完成后使用
//    测试数据集应该更准确的评估模型将如何在新数据上执行

//    当我们有了数据集以后
//    数据集加载为N维向量,我们需要通过Translator来转换数据集

    @GetMapping("/test02")
    public String test02() {
//        输入的图片像素
        long inputSize = 28 * 28;
//        输出的图片类型
        long outputSize = 10;
//        整一个批量扁平块,把二维图像输入转为一维特征向量
        SequentialBlock block = new SequentialBlock();
//        添加扁平块
        block.add(Blocks.batchFlattenBlock(inputSize));
//        添加一个隐藏层,线性变化大小为128
        block.add(Linear.builder().setUnits(128).build());
//        添加相应的激活函数
        block.add(Activation::relu);

//        第二个隐藏层的激活函数,这一层是大小为64的变化
        block.add(Linear.builder().setUnits(64).build());
        block.add(Activation::relu);

//        我试着添加一个32的隐藏层激活函数
        block.add(Linear.builder().setUnits(32).build());
        block.add(Activation::relu);

//        最后输出10大小的特征向量
        block.add(Linear.builder().setUnits(outputSize).build());

//        这些大小是在实验过程中选择的
//        围绕块,可以构建我们的模型,添加一些重要的元数据,如可以在训练和推理时使用的超参数
        Model model = Model.newInstance("mlp");
        model.setBlock(block);
//        现在就拥有了块和模型了,剩下的就是如何进行训练的部分
        return "构建块和模型";
    }

//    因此本质上,我们的模型训练就是,我们构建一个函数量,再经由我们创建的这个模型
//    模型内部使用的就是我们配置的激活函数
//    一层一层训练,直到最后精度损失控制到一定程度,停止训练
//    然后我们就可以使用这个模型进行预测了
//    模型的工作原理就是,由一个N维数组经过训练和优化,变成一个N-1维数组(另外一个N维数组)
//    也可以说,所谓的模型就是对我们的Block量的一个封装

//    也就是我们自己配置和封装对应的算法和输入输出,得到我们需要的模型,接下来就是使用这个模型进行训练,让它的精度损失控制到我们需要的水平
//    再经过正向传播,反向传播的多轮训练,直到精度损失控制到我们想要的水平

//    除此以外就是需要把我们的数据或者说训练内容进行转换,使用Translator,得到对应的N维数组
    @GetMapping("/test03")
    public String test03() {
//        这里我们使用Translator进行数据预处理的主要目的是,解决训练数据的格式不一致的问题
//        毕竟输入的数据集和模型训练的维度不一定完全一致,所以需要将我们的数据集进行预处理,使其符合训练的维度
//        比如说这里我们如果不是28*28的图像,但是模型训练的维度是28*28,所以需要将我们的数据集进行预处理,使其符合训练的维度
        ImageClassificationTranslator classificationTranslator = ImageClassificationTranslator.builder().addTransform(new CenterCrop())//中心裁剪
                .addTransform(new Resize(28, 28))//调整尺寸
                .addTransform(new ToTensor())//将图像N维数组从预处理的格式转换为神经网络格式的变换
                .build();
        return "参数预处理";
    }

//    最终达到的效果就是,把输入数据,不管是图片视频文字语言还是任何其他的东西转为机器能识别的数据量或者说数字量
//    更直接就是N维数组,将输出的内容转为人类可以识别的内容,和对应的概率
//    数字化,数字化,任何东西严格意义上都可以用数字来表示,只要你有办法去表示,那么你就可以用数字来表示

    @GetMapping("/test04")
    public String test04() throws TranslateException, ModelNotFoundException, MalformedModelException, IOException {
        ImageClassificationTranslator classificationTranslator = ImageClassificationTranslator.builder().addTransform(new CenterCrop())//中心裁剪
                .addTransform(new Resize(28, 28))//调整尺寸
                .addTransform(new ToTensor())//将图像N维数组从预处理的格式转换为神经网络格式的变换
                .build();
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .optFilter("layer","50")
                .optTranslator(classificationTranslator)
                .optProgress(new ProgressBar())
                .build();
        ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
        log.info(model.getName());
        Predictor<Image, Classifications> predictor = model.newPredictor();
//        这里我们传递一个图片给他去做预测
//        假装有这么一个图片
        Classifications predict = predictor.predict(null);
        return "使用模型";
    }
}
